{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensorflow Version: 2.0.0-rc0\n"
     ]
    }
   ],
   "source": [
    "print('Tensorflow Version: {}'.format(tf.__version__))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "(train_image, train_lable), (test_image, test_label) = tf.keras.datasets.fashion_mnist.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(60000, 28, 28)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_image.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(60000,)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_lable.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((10000, 28, 28), (10000,))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_image.shape, test_label.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x1f6687ee3c8>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAUDklEQVR4nO3da2yc1ZkH8P8z4/ElzjiJc3FCcAmXUJLCEqhJgFSUkkJDtNqQUioQYkFCG7QL3bbLBxDtquyXFUILCC277RrIElaFqlVBUBRRgrlkgZLGhJTcNgQSk5tjOzGxHcdjz+XZDx5aE3ye18w7M+/A+f8ky/Y8PjPHM/77nZnznnNEVUFEX36xqDtAROXBsBN5gmEn8gTDTuQJhp3IE1XlvLFqqdFa1JfzJom8ksIgRnRYxquFCruILAfwMIA4gMdU9T7r52tRjyWyLMxNEpFho7Y5awU/jReROID/AHA1gIUAbhCRhYVeHxGVVpjX7IsBfKCqe1R1BMCvAKwsTreIqNjChH0ugP1jvj+Qv+xTRGS1iLSLSHsawyFujojCCBP28d4E+My5t6raqqotqtqSQE2ImyOiMMKE/QCA5jHfnwrgULjuEFGphAn7JgDzReR0EakGcD2A54vTLSIqtoKH3lQ1IyJ3APg9Rofe1qjq9qL1jIiKKtQ4u6quA7CuSH0hohLi6bJEnmDYiTzBsBN5gmEn8gTDTuQJhp3IEww7kScYdiJPMOxEnmDYiTzBsBN5gmEn8gTDTuSJsi4lTRGQcVcV/ouQG3vGpzea9Y+/c7az1vDU26FuO+h3k6qEs6bpkXC3HVbQ42Ip8DHjkZ3IEww7kScYdiJPMOxEnmDYiTzBsBN5gmEn8gTH2b/kJB4365rJmPXYInuvzp23TbbbD7lricHFZtuqoZxZT7zUbtZDjaUHjeEH3K8Q+zgapm9SZcTWeDh5ZCfyBMNO5AmGncgTDDuRJxh2Ik8w7ESeYNiJPMFx9i85c0wWwePs+78z1azfeMn/mvU3e85w1j6qmW221TqzjKpvX2LWz/7Pg85apmOffeUBc8aD7rcg8WnT3MVs1myb7e93F41uhwq7iHQAGACQBZBR1ZYw10dEpVOMI/u3VPVIEa6HiEqIr9mJPBE27ArgJRF5R0RWj/cDIrJaRNpFpD2N4ZA3R0SFCvs0fqmqHhKRWQDWi8j/qeqGsT+gqq0AWgGgQRrDrW5IRAULdWRX1UP5z90AngVgT2MiosgUHHYRqReR5CdfA7gKwLZidYyIiivM0/gmAM/K6LzfKgBPqeqLRekVFU0ulQrVfuSC42b9e1PsOeW1sbSz9nrMnq9+8JVms579K7tvHz2YdNZy715qtp2+zR7rbni306wfuWyuWe/5uvsVbVPAcvrTXv7QWZNed6QLDruq7gFwfqHtiai8OPRG5AmGncgTDDuRJxh2Ik8w7ESeEA25Ze/n0SCNukSWle32vGEtexzw+B7//sVm/eqfvmbWF9QeMusDuVpnbUTDncD5yK5vmvXBPVOctdhIwJbJAeVsk70UtKbt4+i0ze7fvW5ll9lWHp3prL3X9jCO9+4ft/c8shN5gmEn8gTDTuQJhp3IEww7kScYdiJPMOxEnuA4eyUI2B44lIDH99x37P/3351mT2ENEjfWNh7UarPtsWx9qNvuybinuKYDxvgf221PgT1ujOEDQCxjP6ZXfutdZ+3axk1m2/vPPM9Z26ht6NdejrMT+YxhJ/IEw07kCYadyBMMO5EnGHYiTzDsRJ7gls2VoIznOpxs9/FZZv1ow2Szfjhjb+k8Pe5e7jkZGzLbzkvY+4X2ZN3j6AAQT7iXqh7RuNn2X772O7OeWpAw6wmxl6K+1FgH4Lodf2u2rcces+7CIzuRJxh2Ik8w7ESeYNiJPMGwE3mCYSfyBMNO5AmOs3tuZo297XGtuLdcBoBqyZj1Q+lpztruoa+a
bd/vt88BWN603aynjbF0a549EDxOfkriY7OeUnsc3rpXlzbZ4+hbzKpb4JFdRNaISLeIbBtzWaOIrBeR3fnP7keUiCrCRJ7GPwFg+UmX3Q2gTVXnA2jLf09EFSww7Kq6AUDvSRevBLA2//VaANcUuV9EVGSFvkHXpKqdAJD/7HxxJSKrRaRdRNrTGC7w5ogorJK/G6+qraraoqotCdSU+uaIyKHQsHeJyBwAyH/uLl6XiKgUCg378wBuzn99M4DnitMdIiqVwHF2EXkawOUAZojIAQA/A3AfgF+LyK0A9gG4rpSd/NILWDde4vbca824x7rj0+xR0W9O3WrWe7INZv1YdpJZnxo/4awNZNx7twNA75B93efUdJr1zSfmOWszq+1xcqvfANAxMsOsz685bNbv73Lvn9Bce/L74Z+WWXaZs6Yb/+CsBYZdVW9wlLjbA9EXCE+XJfIEw07kCYadyBMMO5EnGHYiT3CKayUIWEpaquyHyRp623/rArPtFZPsJZPfSs016zOrBsy6Nc10Tk2f2TbZlDLrQcN+jVXu6bsD2Tqz7aSYfWp30O99YbW9DPaPX77QWUuee9Rs25AwjtHGKC6P7ESeYNiJPMGwE3mCYSfyBMNO5AmGncgTDDuRJzjOXgEkUW3Wcyl7vNkyY+uIWT+StZc8nhqzp3pWByy5bG2NfGnjXrNtT8BY+Oah0816Mu7eEnpmzB4nb07YY91bU81mfd3gWWb91r9+2Vl7uvVKs231i285a6Lux4tHdiJPMOxEnmDYiTzBsBN5gmEn8gTDTuQJhp3IE1+scXZjyWWpsseLJR7wfy1m13MpY35zzh5rDqJpeyw8jIf/6xGzvj8z1awfTtv1oCWXs8YE67eHpphta2P2dtEzq/rNen/OHqe3DOTsZa6tefpAcN/vmr7bWXum79tm20LxyE7kCYadyBMMO5EnGHYiTzDsRJ5g2Ik8wbATeaKixtnDrI8eNFat9rBnpIZWLjbr+6+xx/FvvOCPztrhTNJs+66xrTEATDHmhANAfcD66il1n/9waMTeTjporNpaFx4AZhnj8Fm1j3MH03bfggSdf3AgY6xp/zf2XPupTxbUpeAju4isEZFuEdk25rJ7ReSgiGzJf6wo7OaJqFwm8jT+CQDLx7n8IVVdlP9YV9xuEVGxBYZdVTcA6C1DX4iohMK8QXeHiLyXf5rvfIEjIqtFpF1E2tOwX98RUekUGvafAzgTwCIAnQAecP2gqraqaouqtiRQU+DNEVFYBYVdVbtUNauqOQCPArDfTiaiyBUUdhGZM+bbVQC2uX6WiCpD4Di7iDwN4HIAM0TkAICfAbhcRBYBUAAdAG4rRmescfSwqubMNuvp05vMeu8C917gJ2Ybm2IDWLRip1m/pem/zXpPtsGsJ8TYnz093Wx7waQOs/5K30KzfqRqslm3xukvrXfP6QaAYzl7//VTqj4263d98D1nrWmSPZb92Gn2AFNac2Z9V9p+ydqXc8+H/8eFr5ptn8VMs+4SGHZVvWGcix8v6NaIKDI8XZbIEww7kScYdiJPMOxEnmDYiTxRUVNch6++yKzP+skeZ21RwwGz7cK6N8x6KmcvRW1Nt9wxNNdseyJnb8m8e8QeFuzL2ENQcXEPA3WP2FNcH9hrL1vctvgXZv2nh8abI/UXsTp11o5m7WG7ayfbS0UD9mN221c2OGtnVHebbV8YnGPWDwVMgW1K9Jn1eYkeZ+27yffNtoUOvfHITuQJhp3IEww7kScYdiJPMOxEnmDYiTzBsBN5orzj7GIvF73kXzeZzZcltztrJ9SeUhg0jh40bmqZUmUvGzyctu/m7rQ9hTXI2TWHnbVVDVvMthseWWLWv5H6gVn/8Ap7em7bkHsqZ0/G/r2v33uFWd+8r9msXzxvr7N2XvKg2Tbo3IZkPGXWrWnHADCYc/+9vp2yzz8oFI/sRJ5g2Ik8wbATeYJhJ/IEw07kCYadyBMMO5EnRNU937jY6mY3
65k3/ZOz3nr7v5vtn+q92FlrrrW3ozut+ohZnx63t/+1JGP2mOtXE/aY6wuDp5r1146dY9a/nuxw1hJib/d8+aQPzPotP77TrGdq7WW0++e5jyeZevtvr+H8o2b9B2e9Ytarjd/9WNYeRw+634K2ZA5irUGQjNnbZD+wYpWz9oeOJ9A31Dnug8IjO5EnGHYiTzDsRJ5g2Ik8wbATeYJhJ/IEw07kibLOZ4+lgUld7vHFF/oXme3PqHOvtX0kba+P/vvj55n1U+vs7X+trYfPMuaTA8CW1FSz/mLP18z6KXX2+uld6SnO2tF0vdn2hDGvGgAef+hBs/5Al73u/KrGzc7a+dX2OPqxnH0s2hGw3v5ArtZZS6m9vkFfwDh80vh7AIC02tGKG1s+T43ZY/j957m34c52uW838MguIs0i8qqI7BSR7SLyw/zljSKyXkR25z8XvvoDEZXcRJ7GZwDcqaoLAFwM4HYRWQjgbgBtqjofQFv+eyKqUIFhV9VOVd2c/3oAwE4AcwGsBLA2/2NrAVxTqk4SUXif6w06EZkH4AIAGwE0qWonMPoPAcAsR5vVItIuIu2Z4cFwvSWigk047CIyGcBvAfxIVYN23PszVW1V1RZVbamqsd8sIqLSmVDYRSSB0aD/UlWfyV/cJSJz8vU5AOxtMYkoUoFDbyIiAB4HsFNVx47DPA/gZgD35T8/F3Rd8ZEckvuHnfWc2tMlXzninurZVDtgtl2U3G/Wd52wh3G2Dp3irG2u+orZti7u3u4ZAKZU21Nk66vc9xkAzEi4f/fTa+z/wdY0UADYlLJ/t7+f+ZpZ35dxD9L8bvBss+2OE+77HACmBSzhvbXf3f5Ext5GezhrRyOVsYdyp9TYj+lFjR85a7tgbxfdc74xbfhNd7uJjLMvBXATgK0i8ski5PdgNOS/FpFbAewDcN0ErouIIhIYdlV9A4DrkLusuN0holLh6bJEnmDYiTzBsBN5gmEn8gTDTuSJ8m7ZfHwIsdffdZZ/89JSs/k/r/yNs/Z6wHLLLxy2x0X7R+ypnjMnuU/1bTDGuQGgMWGfJhy05XNtwPa/H2fcZyYOx+ypnFnnQMuow8Pu6bMA8GZuvllP59xbNg8bNSD4/ITekRlm/ZS6PmdtIOOe/goAHQONZv1In72tcmqSHa03smc6a8tnu7cmB4C6bvdjFjP+VHhkJ/IEw07kCYadyBMMO5EnGHYiTzDsRJ5g2Ik8UdYtmxukUZdI4RPl+m50b9l8xj/sMtsunrrXrG/ut+dt7zPGXdMBSx4nYu5lgwFgUmLErNcGjDdXx91z0mOwH99cwDh7fdzuW9Bc+4Yq97zuZNye8x0ztjWeiLjxu/+xb16o604G/N4Ztf8mLpnyobO2Zu+lZtspK9zbbG/UNvRrL7dsJvIZw07kCYadyBMMO5EnGHYiTzDsRJ5g2Ik8Uf5x9vhV7h/I2WuYhzF47RKzvuSeTXY96R4XPae6y2ybgD1eXBswnlwfs8fCU8ZjGPTf/I2hZrOeDbiGVz5eYNbTxnhz14kGs23COH9gIqx9CIYyAVs2D9nz3eMxOzep1+y59tN3uM+dqFln/y1aOM5ORAw7kS8YdiJPMOxEnmDYiTzBsBN5gmEn8kTgOLuINAN4EsBsADkArar6sIjcC+DvAPTkf/QeVV1nXVfY+eyVSi6y16Qfml1n1muO2nOjB06z2zd86F6XPjZsrzmf+9NOs05fLNY4+0Q2icgAuFNVN4tIEsA7IrI+X3tIVf+tWB0lotKZyP7snQA6818PiMhOAHNL3TEiKq7P9ZpdROYBuADAxvxFd4jIeyKyRkSmOdqsFpF2EWlPw366SkSlM+Gwi8hkAL8F8CNV7QfwcwBnAliE0SP/A+O1U9VWVW1R1ZYE7P3UiKh0JhR2EUlgNOi/VNVnAEBVu1Q1q6o5AI8CWFy6bhJRWIFhFxEB8DiAnar64JjL54z5sVUAthW/e0RULBN5N34pgJsAbBWRLfnL
7gFwg4gsAqAAOgDcVpIefgHopq1m3Z4sGazhrcLbhluMmb5MJvJu/BvAuIuLm2PqRFRZeAYdkScYdiJPMOxEnmDYiTzBsBN5gmEn8gTDTuQJhp3IEww7kScYdiJPMOxEnmDYiTzBsBN5gmEn8kRZt2wWkR4AH425aAaAI2XrwOdTqX2r1H4B7Fuhitm301R15niFsob9Mzcu0q6qLZF1wFCpfavUfgHsW6HK1Tc+jSfyBMNO5Imow94a8e1bKrVvldovgH0rVFn6FulrdiIqn6iP7ERUJgw7kSciCbuILBeRXSLygYjcHUUfXESkQ0S2isgWEWmPuC9rRKRbRLaNuaxRRNaLyO7853H32Iuob/eKyMH8fbdFRFZE1LdmEXlVRHaKyHYR+WH+8kjvO6NfZbnfyv6aXUTiAN4HcCWAAwA2AbhBVXeUtSMOItIBoEVVIz8BQ0QuA3AcwJOqem7+svsB9Krqffl/lNNU9a4K6du9AI5HvY13freiOWO3GQdwDYBbEOF9Z/Tr+yjD/RbFkX0xgA9UdY+qjgD4FYCVEfSj4qnqBgC9J128EsDa/NdrMfrHUnaOvlUEVe1U1c35rwcAfLLNeKT3ndGvsogi7HMB7B/z/QFU1n7vCuAlEXlHRFZH3ZlxNKlqJzD6xwNgVsT9OVngNt7ldNI24xVz3xWy/XlYUYR9vK2kKmn8b6mqXgjgagC355+u0sRMaBvvchlnm/GKUOj252FFEfYDAJrHfH8qgEMR9GNcqnoo/7kbwLOovK2ouz7ZQTf/uTvi/vxZJW3jPd4246iA+y7K7c+jCPsmAPNF5HQRqQZwPYDnI+jHZ4hIff6NE4hIPYCrUHlbUT8P4Ob81zcDeC7CvnxKpWzj7dpmHBHfd5Fvf66qZf8AsAKj78h/COAnUfTB0a8zAPwp/7E96r4BeBqjT+vSGH1GdCuA6QDaAOzOf26soL79D4CtAN7DaLDmRNS3b2D0peF7ALbkP1ZEfd8Z/SrL/cbTZYk8wTPoiDzBsBN5gmEn8gTDTuQJhp3IEww7kScYdiJP/D866iIlQ3gtyAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(train_image[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "255"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.max(train_image[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([9, 0, 0, ..., 3, 0, 5], dtype=uint8)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_lable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale raw pixel intensities (max value 255) into [0, 1]. Casting to float32 first\n",
    "# matches the Keras default dtype (floatx), so layers do not need to autocast\n",
    "# float64 inputs and the arrays take half the memory.\n",
    "# NOTE: this cell rebinds its own inputs; run it only once per kernel session.\n",
    "train_image = train_image.astype(np.float32) / 255.0\n",
    "test_image = test_image.astype(np.float32) / 255.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(60000, 28, 28)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_image.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential()\n",
    "model.add(tf.keras.layers.Flatten(input_shape=(28,28)))  # 28*28\n",
    "model.add(tf.keras.layers.Dense(128, activation='relu'))\n",
    "model.add(tf.keras.layers.Dense(10, activation='softmax'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "flatten (Flatten)            (None, 784)               0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 128)               100480    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 101,770\n",
      "Trainable params: 101,770\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(optimizer='adam',\n",
    "              loss='sparse_categorical_crossentropy',\n",
    "              metrics=['acc']\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples\n",
      "Epoch 1/3\n",
      "60000/60000 [==============================] - 5s 82us/sample - loss: 0.5020 - acc: 0.8239\n",
      "Epoch 2/3\n",
      "60000/60000 [==============================] - 4s 68us/sample - loss: 0.3754 - acc: 0.8650\n",
      "Epoch 3/3\n",
      "60000/60000 [==============================] - 4s 68us/sample - loss: 0.3381 - acc: 0.8757\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x28fb0751cc0>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(train_image, train_lable, epochs=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.3684553307771683, 0.8659]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(test_image, test_label, verbose=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 保存整个模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "整个模型可以保存到一个文件中，其中包含权重值、模型配置乃至优化器配置。这样，您就可以为模型设置检查点，并稍后从完全相同的状态继续训练，而无需访问原始代码。\n",
    "\n",
    "在 Keras 中保存完全可正常使用的模型非常有用，您可以在 TensorFlow.js 中加载它们，然后在网络浏览器中训练和运行它们。\n",
    "\n",
    "Keras 使用 HDF5 标准提供基本的保存格式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save('less_model.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0930 19:59:17.244999 13052 deprecation.py:323] From c:\\users\\guanghua\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\tensorflow_core\\python\\ops\\math_grad.py:1424: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    }
   ],
   "source": [
    "new_model = tf.keras.models.load_model('less_model.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "flatten (Flatten)            (None, 784)               0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 128)               100480    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 101,770\n",
      "Trainable params: 101,770\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "new_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.3684553307771683, 0.8659]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_model.evaluate(test_image, test_label, verbose=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此方法保存以下所有内容：\n",
    "\n",
    "1. 权重值\n",
    "2. 模型配置（架构）\n",
    "3. 优化器配置"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 仅保存架构"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "有时我们只对模型的架构感兴趣，而无需保存权重值或优化器。在这种情况下，可以仅保存模型的“配置” 。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "json_config = model.to_json()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'{\"class_name\": \"Sequential\", \"config\": {\"name\": \"sequential\", \"layers\": [{\"class_name\": \"Flatten\", \"config\": {\"name\": \"flatten\", \"trainable\": true, \"batch_input_shape\": [null, 28, 28], \"dtype\": \"float32\", \"data_format\": \"channels_last\"}}, {\"class_name\": \"Dense\", \"config\": {\"name\": \"dense\", \"trainable\": true, \"dtype\": \"float32\", \"units\": 128, \"activation\": \"relu\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}}, {\"class_name\": \"Dense\", \"config\": {\"name\": \"dense_1\", \"trainable\": true, \"dtype\": \"float32\", \"units\": 10, \"activation\": \"softmax\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}}]}, \"keras_version\": \"2.2.4-tf\", \"backend\": \"tensorflow\"}'"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "json_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "reinitialized_model = tf.keras.models.model_from_json(json_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "flatten (Flatten)            (None, 784)               0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 128)               100480    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 101,770\n",
      "Trainable params: 101,770\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "reinitialized_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "reinitialized_model.compile(optimizer='adam',\n",
    "                           loss='sparse_categorical_crossentropy',\n",
    "                           metrics=['acc']\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[2.5259773368835448, 0.0712]"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reinitialized_model.evaluate(test_image, test_label, verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 仅保存权重"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "有时我们只需要保存模型的状态（其权重值），而对模型架构不感兴趣。在这种情况下，可以通过 `get_weights()` 获取权重值，并通过 `set_weights()` 设置权重值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "weighs = model.get_weights()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "reinitialized_model.set_weights(weighs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.3684553307771683, 0.8659]"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reinitialized_model.evaluate(test_image, test_label, verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_weights('less_weights.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "reinitialized_model.load_weights('less_weights.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.3684553307771683, 0.8659]"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reinitialized_model.evaluate(test_image, test_label, verbose=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 在训练期间保存检查点"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在训练期间或训练结束时自动保存检查点。这样一来，您便可以使用经过训练的模型，而无需重新训练该模型，或从上次暂停的地方继续训练，以防训练过程中断。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "回调函数：tf.keras.callbacks.ModelCheckpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_path = 'training_cp/cp.ckpt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,\n",
    "                                                 save_weights_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential()\n",
    "model.add(tf.keras.layers.Flatten(input_shape=(28,28)))  # 28*28\n",
    "model.add(tf.keras.layers.Dense(128, activation='relu'))\n",
    "model.add(tf.keras.layers.Dense(10, activation='softmax'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(optimizer='adam',\n",
    "              loss='sparse_categorical_crossentropy',\n",
    "              metrics=['acc']\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[2.4373130966186523, 0.0964]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(test_image, test_label, verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x1b2349514a8>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_weights(checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.38093261218070984, 0.8623]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(test_image, test_label, verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples\n",
      "Epoch 1/3\n",
      "60000/60000 [==============================] - 5s 86us/sample - loss: 0.5020 - acc: 0.8227\n",
      "Epoch 2/3\n",
      "60000/60000 [==============================] - 4s 72us/sample - loss: 0.3776 - acc: 0.8643\n",
      "Epoch 3/3\n",
      "60000/60000 [==============================] - 4s 73us/sample - loss: 0.3392 - acc: 0.8749\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x268b4117b00>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(train_image, train_lable, epochs=3, callbacks=[cp_callback])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 自定义训练中保存检查点"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential()\n",
    "model.add(tf.keras.layers.Flatten(input_shape=(28,28)))  # 28*28\n",
    "model.add(tf.keras.layers.Dense(128, activation='relu'))\n",
    "model.add(tf.keras.layers.Dense(10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = tf.keras.optimizers.Adam()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss(model, x, y):\n",
    "    \"\"\"Run `model` on `x` and score the output against `y` with `loss_func`.\n",
    "\n",
    "    `loss_func` is the notebook-level loss object; `y` is passed to it as the\n",
    "    target and the model output as the prediction.\n",
    "    \"\"\"\n",
    "    return loss_func(y, model(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_step(model, images, labels):\n",
    "    \"\"\"Run one optimization step on a single batch and update the training metrics.\n",
    "\n",
    "    Args:\n",
    "        model: Keras model to train; its trainable variables are updated in place.\n",
    "        images: Batch of input images fed to the model.\n",
    "        labels: Class labels for `images`, passed to `loss_func` as the targets.\n",
    "\n",
    "    Relies on the notebook-level `optimizer`, `loss_func`, `train_loss` and\n",
    "    `train_accuracy` objects.\n",
    "    \"\"\"\n",
    "    with tf.GradientTape() as tape:\n",
    "        # training=True keeps train-only layers (e.g. Dropout, BatchNormalization)\n",
    "        # in training mode if they are ever added; Flatten/Dense ignore the flag.\n",
    "        outputs = model(images, training=True)\n",
    "        loss_step = loss_func(labels, outputs)\n",
    "    grads = tape.gradient(loss_step, model.trainable_variables)\n",
    "    optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "    # Accumulate this batch into the running loss and accuracy metrics.\n",
    "    train_loss(loss_step)\n",
    "    train_accuracy(labels, outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)\n",
    "train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')\n",
    "test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)\n",
    "test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "cp_dir = './customtrain_cp'\n",
    "cp_prefix = os.path.join(cp_dir, 'ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n",
    "                                 model=model\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.from_tensor_slices((train_image, train_lable))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.shuffle(10000).batch(32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epochs=5, save_every=2):\n",
    "    \"\"\"Train the notebook-level `model` on `dataset`, checkpointing periodically.\n",
    "\n",
    "    Args:\n",
    "        epochs: Number of full passes over `dataset` (default 5).\n",
    "        save_every: Save a checkpoint after every `save_every`-th epoch (default 2).\n",
    "            Pass 0 or None to disable checkpointing.\n",
    "\n",
    "    Prints the accumulated loss and accuracy metrics after each epoch, then resets\n",
    "    both so the next epoch is measured on its own.\n",
    "    \"\"\"\n",
    "    for epoch in range(epochs):\n",
    "        for images, labels in dataset:\n",
    "            train_step(model, images, labels)\n",
    "        print('Epoch{} loss is {}'.format(epoch, train_loss.result()))\n",
    "        print('Epoch{} Accuracy is {}'.format(epoch, train_accuracy.result()))\n",
    "        train_loss.reset_states()\n",
    "        train_accuracy.reset_states()\n",
    "        # With the defaults this saves after epochs 2 and 4 (1-based), as before.\n",
    "        if save_every and (epoch + 1) % save_every == 0:\n",
    "            checkpoint.save(file_prefix=cp_prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W1001 19:38:17.738456 12396 base_layer.py:1772] Layer flatten is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.\n",
      "\n",
      "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
      "\n",
      "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch0 loss is 0.49911823868751526\n",
      "Epoch0 Accuracy is 0.8250499963760376\n",
      "Epoch1 loss is 0.3733137547969818\n",
      "Epoch1 Accuracy is 0.8652499914169312\n",
      "Epoch2 loss is 0.33668026328086853\n",
      "Epoch2 Accuracy is 0.8779500126838684\n",
      "Epoch3 loss is 0.3103724718093872\n",
      "Epoch3 Accuracy is 0.886816680431366\n",
      "Epoch4 loss is 0.29314786195755005\n",
      "Epoch4 Accuracy is 0.8917333483695984\n"
     ]
    }
   ],
   "source": [
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x1f64470c6d8>"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "checkpoint.restore(tf.train.latest_checkpoint(cp_dir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([9, 0, 0, ..., 3, 0, 5], dtype=int64)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.argmax(model(train_image, training=False), axis=-1).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([9, 0, 0, ..., 3, 0, 5], dtype=uint8)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_lable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8905166666666666"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(tf.argmax(model(train_image, training=False), axis=-1).numpy() == train_lable).sum()/len(train_lable)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "恢复模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
